import torch
import torch.nn as nn
from torch.autograd import Variable
class TripletLoss(object):
    """Triplet loss over precomputed anchor-positive / anchor-negative distances.

    With a ``margin`` the hinge form ``max(0, d_ap - d_an + margin)`` is used
    (via ``nn.MarginRankingLoss``); with ``margin=None`` the soft-margin form
    ``log(1 + exp(-(d_an - d_ap)))`` is used (via ``nn.SoftMarginLoss``).
    Both are averaged over the batch (the loss modules' default reduction).
    """

    def __init__(self, margin=None):
        """
        Args:
            margin: float hinge margin, or ``None`` to use the soft-margin
                formulation instead.
        """
        self.margin = margin
        if margin is not None:
            self.ranking_loss = nn.MarginRankingLoss(margin=margin)
        else:
            self.ranking_loss = nn.SoftMarginLoss()

    def __call__(self, dist_ap, dist_an):
        """Compute the triplet loss for a batch of distances.

        Args:
            dist_ap: tensor of anchor-positive distances, shape [N].
            dist_an: tensor of anchor-negative distances, shape [N].

        Returns:
            Scalar (0-dim) tensor holding the batch-mean loss.
        """
        # A target of 1 means "dist_an should rank above dist_ap".
        # ones_like matches dtype/device/shape and does not require grad; it
        # replaces the deprecated Variable(x.data.new().resize_as_(...).fill_(1)).
        y = torch.ones_like(dist_an)
        if self.margin is not None:
            loss = self.ranking_loss(dist_an, dist_ap, y)
        else:
            loss = self.ranking_loss(dist_an - dist_ap, y)
        return loss